import torch
from torch import nn
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

import numpy
import math
import argparse
import os
import random
import Model
from SelectDataset import SelectMNIST

# Command-line interface of the Small-MNIST training script.
parser = argparse.ArgumentParser(description='Small_MNIST')

# One row per option: (flag, default, type), registered in this order.
_OPTION_TABLE = (
    ('--epochnumber', 50000, int),
    ('--initial', 1, int),
    ('--mininstance', 1, int),
    ('--instance', 4, int),
    ('--saveinst', 0, int),
    ('--batchsize', 32, int),
    ('--printstep', 60, int),
    ('--lr', 1e-2, float),          # overwritten from --loglr inside Prepare()
    ('--wd', 1 / 32, float),
    ('--loglr', 20, float),         # learning rate is 2 ** (-loglr / 4)
    ('--epochlen', 54, int),        # recomputed from the dataset inside Prepare()
    ('--hdim', 16, int),
    ('--ch', 4, int),
    ('--gpu', "0", str),            # "-1" = cpu
    ('--fileprefix', 'Equi', str),
    ('--load', None, str),
    ('--loadepoch', 0, int),
    ('--saveepoch', 100, int),
    ('--numdigits', 10, int),
    ('--sampleperdigit', 128, int),
)

for _flag, _default, _type in _OPTION_TABLE:
    parser.add_argument(_flag, default=_default, type=_type)
del _flag, _default, _type

def Prepare():
    """Parse the command line and build the module-level state used for training.

    Sets the globals ``args``, ``Device``, ``Criterion``, ``DS_train``,
    ``DL_train``, ``Inst``, ``MinInst``, ``InitParaFile`` and ``LoadParaFile``.

    Side effects:
      * seeds torch's RNG with 0;
      * creates the dataset / parameter / result directories if missing;
      * creates the dataset file via ``SelectMNIST`` if it does not exist;
      * creates the shared initial-parameter file (100 calls to ``Model.Run``)
        if it does not exist;
      * prints, for every instance, the epoch of the checkpoint it resumes from.

    ``LoadParaFile`` ends up as a list with one ``(path, last_epoch)`` tuple per
    instance; ``last_epoch`` is ``-1`` when the instance starts from the shared
    initial parameters.
    """
    global args, Criterion, DS_train, InitParaFile, LoadParaFile, DL_train, Inst, MinInst, Device

    torch.manual_seed(0)

    args = parser.parse_args()

    # The learning rate is always derived from --loglr; a value given as --lr
    # is therefore ignored.
    args.lr = 2**(-args.loglr/4)

    if args.gpu == "-1":
        Device = torch.device("cpu")
    else:
        Device = torch.device("cuda:" + args.gpu)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    for folder in (
        "./dataset/",
        "./Init_parameters/",
        "./%s_parameters/" % args.fileprefix,
        "./%s_results/" % args.fileprefix,
    ):
        os.makedirs(folder, exist_ok=True)

    # Instance ids come in blocks of size Inst: block k (= --mininstance)
    # covers Inst*(k-1)+1 .. Inst*k.
    Inst = args.instance
    MinInst = Inst*args.mininstance-Inst+1

    Criterion = nn.CrossEntropyLoss().to(Device)

    DataFile = "./dataset/SmallMNIST_numdigits_%d_sampleperdigit_%d.data" % (args.numdigits, args.sampleperdigit)
    if not os.path.exists(DataFile):
        SelectMNIST(args.numdigits, args.sampleperdigit)
    # map_location keeps the load working when the file holds tensors that
    # were saved from a device other than the one selected by --gpu.
    Data = torch.load(DataFile, map_location=Device)
    X_train = Data["images"].to(Device)
    y_train = Data["labels"].to(Device)
    DS_train = TensorDataset(X_train, y_train)
    DL_train = DataLoader(DS_train, batch_size=args.batchsize, shuffle=True)
    # Number of batches the loader really yields per epoch.  The loader keeps
    # the last partial batch, so floor(N / batchsize) would under-count (and
    # be 0 for N < batchsize) whenever N is not a multiple of the batch size.
    args.epochlen = len(DL_train)

    # All instances share one initial-parameter file; build it once if needed.
    InitParaFile = "./Init_parameters/InitPara_%d_numdigits_%d.para" % (args.initial, args.numdigits)
    if not os.path.exists(InitParaFile):
        net = Model.Net(args.hdim, args.numdigits, args.ch).to(Device)
        for _ in range(100):
            print(Model.Run(net, DL_train, Criterion, args.lr, args.wd, train=True))
        torch.save(net.state_dict(), InitParaFile)

    checkpoint_fmt = "./%s_parameters/Para_%s_log_%d_init_%d_epoch_%d_inst_%d.pt"

    LoadParaFile = []
    for i in range(MinInst, MinInst+Inst):
        # Walk forward through the epoch checkpoints (saveepoch-1,
        # 2*saveepoch-1, ...) and keep the last one that exists.
        # NOTE(review): the lookup uses instance id i + --saveinst, while
        # learn() writes checkpoints under id i; confirm this is intended
        # when --saveinst is non-zero.
        ParaFile = InitParaFile
        last_epoch = -1
        while True:
            candidate = checkpoint_fmt % (
                args.fileprefix, args.fileprefix, args.loglr, args.initial,
                last_epoch + args.saveepoch, i + args.saveinst,
            )
            if not os.path.exists(candidate):
                break
            ParaFile = candidate
            last_epoch += args.saveepoch
        LoadParaFile.append((ParaFile, last_epoch))
        print("Inst:", i, "Load:", last_epoch)
    
# Must run before learn() is defined: the default arguments of learn() are
# evaluated when its `def` statement executes and read `args`, which only
# exists after Prepare() has parsed the command line.
Prepare()

def learn(i, LR = args.lr, WD = args.wd, EpochNumber = args.epochnumber, LoadEpoch = args.loadepoch):
    """Train instance ``i`` from its latest checkpoint and log the results.

    Args:
        i: instance id; ``LoadParaFile[i - MinInst]`` gives the checkpoint to
            start from and the last epoch it already covers.
        LR: learning rate passed to ``Model.run``.
        WD: weight decay passed to ``Model.run``.
        EpochNumber: exclusive upper bound of the epoch counter.
        LoadEpoch: accepted for compatibility; not used by this function.

    The defaults are taken from ``args`` when the function is *defined*, so
    later changes to ``args`` do not affect them.

    Files written (all via ``torch.save``):
      * for epochs below 200, after every batch: the model state dict and the
        per-batch log of the current epoch so far
        (columns: epoch, batch, loss, accuracy, gradient statistic);
      * every ``args.saveepoch`` epochs: the per-epoch log
        (columns: epoch, LR, mean loss, mean accuracy, mean gradient
        statistic), the stacked ``net.state()`` rows, and the model state
        dict.  Both logs are cleared after being written, so each file holds
        only the epochs since the previous save.
    """

    def artifact_path(folder, kind, epoch, batch=None):
        # Build "./<prefix>_<folder>/<kind>_<prefix>_log_.._init_.._epoch_..
        # [_batch_..]_inst_<i>.pt"; %d deliberately truncates the float loglr.
        stem = "./%s_%s/%s_%s_log_%d_init_%d_epoch_%d" % (
            args.fileprefix, folder, kind, args.fileprefix,
            args.loglr, args.initial, epoch,
        )
        if batch is not None:
            stem += "_batch_%d" % batch
        return stem + "_inst_%d.pt" % i

    start_file, last_done_epoch = LoadParaFile[i-MinInst]

    net = Model.Net(args.hdim, args.numdigits, args.ch).to(Device)
    # map_location: a checkpoint saved from another device (e.g. cuda:0) must
    # still load when running on the CPU or on a different GPU.
    net.load_state_dict(torch.load(start_file, map_location=Device))

    Output = torch.empty(0, 5)
    # net.state() is stacked row-wise below; presumably a 1-D parameter
    # vector (project-defined, see Model) -- TODO confirm.
    weight_Output = torch.empty(0, net.state().size()[0])

    # Batches actually produced per epoch.  The loader keeps the last partial
    # batch, so this can exceed dataset_size // batchsize; averaging over the
    # true count keeps the epoch means correct in that case.
    num_batches = len(DL_train)

    for epochcounter in range(last_done_epoch+1, EpochNumber):

        TrainLoss, TrainAccu, TrainGrad = 0, 0, 0

        Output0 = torch.empty(0, 5)
        for batchcounter, (X, y) in enumerate(DL_train):
            net.train()

            L, A, G = Model.run(net, X, y, Criterion, LR, WD, True)
            output0 = torch.tensor([epochcounter, batchcounter, L, A, G])
            Output0 = torch.cat((Output0, output0.unsqueeze(0))).detach()

            # Fine-grained snapshots for the first 200 epochs only.
            if epochcounter < 200:
                torch.save(net.state_dict(),
                           artifact_path("parameters", "Para", epochcounter, batchcounter))
                torch.save(Output0,
                           artifact_path("results", "Loss", epochcounter, batchcounter))

            TrainLoss += L.item()
            TrainAccu += A.item()
            TrainGrad += G.item()

        TrainLoss /= num_batches
        TrainAccu /= num_batches
        TrainGrad /= num_batches

        output = torch.tensor([epochcounter, LR,
                               TrainLoss, TrainAccu, TrainGrad])
        weight_output = net.state()

        Output = torch.cat((Output, output.unsqueeze(0))).detach()
        weight_Output = torch.cat((weight_Output, weight_output.unsqueeze(0).cpu())).detach()

        if (epochcounter+1) % args.saveepoch == 0:
            torch.save(Output, artifact_path("results", "Loss", epochcounter))
            torch.save(weight_Output, artifact_path("results", "Weight", epochcounter))

            # Start fresh logs so the next file holds only new epochs.
            Output = torch.empty(0, 5)
            weight_Output = torch.empty(0, weight_output.size()[0])

            torch.save(net.state_dict(), artifact_path("parameters", "Para", epochcounter))

            print("initial:", args.initial, "instance:", i, "epoch:", epochcounter,TrainLoss,TrainAccu, TrainGrad)
    
# Train the requested instances one after another
# (ids MinInst .. MinInst + Inst - 1, as set up by Prepare()).
for instance_id in range(MinInst, MinInst + Inst):
    learn(instance_id)